"""Torch modules for Graph Isomorphism Network (GIN) layers, a client-server GIN model, and their evaluation helpers and tests."""
# Portions of this code are inspired by the DGL (Deep Graph Library) project.
# See https://github.com/dmlc/dgl for reference implementations and license.
import torch as th
from torch import nn
import torch.nn.functional as F

import dgl.function as fn
from dgl.utils import expand_as_pair
import dgl
import copy
import time
import unittest


class GINConv(nn.Module):
    r"""Graph Isomorphism Network layer from `How Powerful are Graph
    Neural Networks? <https://arxiv.org/pdf/1810.00826.pdf>`__

    .. math::
        h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
        \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right)

    If a weight tensor on each edge is provided, the weighted graph convolution is defined as:

    .. math::
        h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
        \mathrm{aggregate}\left(\left\{e_{ji} h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right)

    where :math:`e_{ji}` is the weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.

    Parameters
    ----------
    apply_func : callable activation function/layer or None
        If not None, apply this function to the updated node feature,
        the :math:`f_\Theta` in the formula, default: None.
    aggregator_type : str
        Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.
    init_eps : float, optional
        Initial :math:`\epsilon` value, default: ``0``.
    learn_eps : bool, optional
        If True, :math:`\epsilon` will be a learnable parameter. Default: ``False``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Raises
    ------
    KeyError
        If ``aggregator_type`` is not one of ``sum``, ``max`` or ``mean``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import GINConv
    >>>
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> feat = th.ones(6, 10)
    >>> lin = th.nn.Linear(10, 10)
    >>> conv = GINConv(lin, 'max')
    >>> res = conv(g, feat)
    >>> res
    tensor([[-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,
            0.8843, -0.8764],
            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,
            0.8843, -0.8764],
            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,
            0.8843, -0.8764],
            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,
            0.8843, -0.8764],
            [-0.4821,  0.0207, -0.7665,  0.5721, -0.4682, -0.2134, -0.5236,  1.2855,
            0.8843, -0.8764],
            [-0.1804,  0.0758, -0.5159,  0.3569, -0.1408, -0.1395, -0.2387,  0.7773,
            0.5266, -0.4465]], grad_fn=<AddmmBackward>)

    >>> # With activation
    >>> from torch.nn.functional import relu
    >>> conv = GINConv(lin, 'max', activation=relu)
    >>> res = conv(g, feat)
    >>> res
    tensor([[5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [5.0118, 0.0000, 0.0000, 3.9091, 1.3371, 0.0000, 0.0000, 0.0000, 0.0000,
             0.0000],
            [2.5011, 0.0000, 0.0089, 2.0541, 0.8262, 0.0000, 0.0000, 0.1371, 0.0000,
             0.0000]], grad_fn=<ReluBackward0>)
    """

    def __init__(
        self,
        apply_func=None,
        aggregator_type="sum",
        init_eps=0,
        learn_eps=False,
        activation=None,
    ):
        super(GINConv, self).__init__()
        self.apply_func = apply_func
        self._aggregator_type = aggregator_type
        self.activation = activation
        if aggregator_type not in ("sum", "max", "mean"):
            raise KeyError(
                "Aggregator type {} not recognized.".format(aggregator_type)
            )
        # A Parameter is trained by the optimizer; a buffer is saved in the
        # state_dict and moved with .to(device) but never receives gradients.
        if learn_eps:
            self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
        else:
            self.register_buffer("eps", th.FloatTensor([init_eps]))

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute Graph Isomorphism Network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in})` and :math:`(N_{out}, D_{in})`.
            If ``apply_func`` is not None, :math:`D_{in}` should
            fit the input dimensionality requirement of ``apply_func``.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge, with one leading entry per edge. If
            given, each message :math:`h_j` is multiplied by its edge weight
            before aggregation.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, D_{out})` where
            :math:`D_{out}` is the output dimensionality of ``apply_func``.
            If ``apply_func`` is None, :math:`D_{out}` should be the same
            as input dimensionality. On a graph without edges the
            neighbourhood term is zero for every destination node.
        """
        _reducer = getattr(fn, self._aggregator_type)
        with graph.local_scope():
            aggregate_fn = fn.copy_u("h", "m")
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.num_edges(), (
                    "edge_weight must have one entry per edge"
                )
                graph.edata["_edge_weight"] = edge_weight
                aggregate_fn = fn.u_mul_e("h", "_edge_weight", "m")

            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata["h"] = feat_src
            graph.update_all(aggregate_fn, _reducer("m", "neigh"))
            if "neigh" in graph.dstdata:
                neigh = graph.dstdata["neigh"]
            else:
                # DGL may skip message passing entirely on a graph without
                # edges, in which case no "neigh" field is written. Every
                # destination node then has an empty neighbourhood. Use zeros,
                # which is what the built-in reducers give zero-in-degree nodes.
                neigh = th.zeros_like(feat_dst)
            rst = (1 + self.eps) * feat_dst + neigh
            if self.apply_func is not None:
                rst = self.apply_func(rst)
            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            return rst


class MLP(nn.Module):
    """Two-layer perceptron used as the ``apply_func`` of a GIN layer.

    Computes ``linears[1](relu(batch_norm(linears[0](x))))``. Both linear
    maps are created with ``bias=False``.

    Parameters
    ----------
    input_dim : int
        Size of each input row.
    hidden_dim : int
        Width of the hidden layer (also the number of batch-norm features).
    output_dim : int
        Size of each output row.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # The attribute names `linears` and `batch_norm` determine the
        # state_dict keys, so they are kept stable.
        self.linears = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, bias=False),
                nn.Linear(hidden_dim, output_dim, bias=False),
            ]
        )
        self.batch_norm = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        first_layer, second_layer = self.linears
        normalized = self.batch_norm(first_layer(x))
        return second_layer(F.relu(normalized))


class CustomGINConv(nn.Module):
    r"""GIN convolution evaluated separately on a local and a remote block.

    For each destination node it computes the same quantity as
    :class:`GINConv`,

    .. math::
        h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{l} +
        \mathrm{aggregate}\left(\left\{h_j^{l}, j\in\mathcal{N}(i)
        \right\}\right)\right),

    but the aggregation runs on two blocks independently: a local
    (client-side) block and a remote (server-side) block. The two sets of
    destination rows are then concatenated, local rows first.
    ``apply_func`` and ``activation`` are applied to the concatenated result,
    so both blocks share them.

    Parameters
    ----------
    apply_func : callable activation function/layer or None
        If not None, apply this function to the updated node feature.
    aggregator_type : str
        Aggregator type to use (``sum``, ``max`` or ``mean``), default: 'sum'.
    init_eps : float, optional
        Initial epsilon value, default: 0.
    learn_eps : bool, optional
        If True, epsilon will be a learnable parameter, default: False.
    activation : callable, optional
        Activation function to use, default: None.

    Raises
    ------
    KeyError
        If ``aggregator_type`` is not one of ``sum``, ``max`` or ``mean``.
    """
    def __init__(
        self,
        apply_func=None,
        aggregator_type="sum",
        init_eps=0,
        learn_eps=False,
        activation=None,
    ):
        super(CustomGINConv, self).__init__()
        self.apply_func = apply_func
        self._aggregator_type = aggregator_type
        self.activation = activation
        if aggregator_type not in ("sum", "max", "mean"):
            raise KeyError(
                "Aggregator type {} not recognized.".format(aggregator_type)
            )
        # A Parameter is trained by the optimizer; a buffer is saved in the
        # state_dict and moved with .to(device) but never receives gradients.
        if learn_eps:
            self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
        else:
            self.register_buffer("eps", th.FloatTensor([init_eps]))

    def _aggregate(self, block, feat_src, reducer):
        """Return ``(1 + eps) * h_dst + aggregate(h_src)`` for one block."""
        with block.local_scope():
            src_h, dst_h = expand_as_pair(feat_src, block)
            block.srcdata["h"] = src_h
            block.update_all(fn.copy_u("h", "m"), reducer("m", "neigh"))
            if "neigh" in block.dstdata:
                neigh = block.dstdata["neigh"]
            else:
                # DGL may skip message passing entirely on a block without
                # edges, in which case no "neigh" field is written. Every
                # destination node then has an empty neighbourhood. Use zeros,
                # which is what the built-in reducers give zero-in-degree
                # nodes. `evaluate` asserts that the last remote block has no
                # edges, so this path is reached in practice.
                neigh = th.zeros_like(dst_h)
            return (1 + self.eps) * dst_h + neigh

    def forward(self, local_block, remote_block, local_feat_src, remote_feat_src):
        """
        Forward computation that handles local and remote blocks separately.

        Parameters
        ----------
        local_block : DGLBlock
            The local block.
        remote_block : DGLBlock
            The remote block.
        local_feat_src : torch.Tensor
            Features of the source nodes in the local block.
        remote_feat_src : torch.Tensor
            Features of the source nodes in the remote block.

        Returns
        -------
        torch.Tensor
            The output features after GIN convolution. There is one row per
            destination node of ``local_block``, followed by one row per
            destination node of ``remote_block``.
        """
        reducer = getattr(fn, self._aggregator_type)

        # Local block (client-side computation), then remote block
        # (server-side computation); both use the same eps.
        local_rst = self._aggregate(local_block, local_feat_src, reducer)
        remote_rst = self._aggregate(remote_block, remote_feat_src, reducer)

        combined_rst = th.cat([local_rst, remote_rst], dim=0)

        # Apply MLP/activation function
        if self.apply_func is not None:
            combined_rst = self.apply_func(combined_rst)

        # Apply activation
        if self.activation is not None:
            combined_rst = self.activation(combined_rst)

        return combined_rst

class ClientServerGIN(nn.Module):
    """GIN model for a client-server (local/remote block) mini-batch setting.

    Each layer follows the pattern of DGL's official GIN example:

    1. a :class:`CustomGINConv` whose ``apply_func`` is a two-layer :class:`MLP`;
    2. optional batch normalization;
    3. ``activation``.

    After the last layer, ``linear_prediction[-1]`` maps the hidden features to
    ``n_classes`` scores. Dropout follows the readout, as in that example's
    layer readout.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    n_hidden : int
        Hidden layer size.
    n_classes : int
        Number of output classes.
    n_layers : int
        Number of GIN layers.
    activation : callable
        Activation function.
    dropout : float
        Dropout rate applied to the class scores.
    aggregator_type : str
        Aggregator type, default: 'sum'.
    init_eps : float
        Initial epsilon value, default: 0.
    learn_eps : bool
        Whether to learn epsilon, default: False.
    use_batch_norm : bool
        Whether to apply ``BatchNorm1d`` after every GIN layer, default: True.
    """
    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation=F.relu, dropout=0.5, aggregator_type='sum',
                 init_eps=0, learn_eps=False, use_batch_norm=True):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.use_batch_norm = use_batch_norm

        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if use_batch_norm else None

        # Create MLPs and GIN layers
        for layer in range(n_layers):
            if layer == 0:
                mlp = MLP(in_feats, n_hidden, n_hidden)
            else:
                mlp = MLP(n_hidden, n_hidden, n_hidden)

            self.ginlayers.append(
                CustomGINConv(mlp, aggregator_type, init_eps, learn_eps)
            )

            if use_batch_norm:
                self.batch_norms.append(nn.BatchNorm1d(n_hidden))

        # One linear head per representation (input + each layer).
        # forward() uses only the last head. The others are kept so that the
        # state_dict keys, and therefore existing checkpoints, stay the same.
        # NOTE(review): the unused heads receive no gradients. Under
        # DistributedDataParallel this needs find_unused_parameters=True.
        self.linear_prediction = nn.ModuleList()
        for layer in range(n_layers + 1):  # +1 for input layer
            if layer == 0:
                self.linear_prediction.append(nn.Linear(in_feats, n_classes))
            else:
                self.linear_prediction.append(nn.Linear(n_hidden, n_classes))

        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, local_feat_src, remote_feat_src, transform_indices):
        """
        Forward computation with local and remote blocks.

        Parameters
        ----------
        blocks : list of tuples
            List of (local_block, remote_block) tuples for each layer.
        local_feat_src : torch.Tensor
            Features of the source nodes in the first local block.
        remote_feat_src : torch.Tensor
            Features of the source nodes in the first remote block.
        transform_indices : list of tuples
            List of (local_indices, remote_indices) for transforming outputs between layers.

        Returns
        -------
        torch.Tensor
            Class scores of shape ``(N, n_classes)``. There is one row per
            destination node of the last (local, remote) block pair, local
            rows first.
        """
        h_local = local_feat_src
        h_remote = remote_feat_src
        last_index = len(self.ginlayers) - 1

        for i, (layer, (local_block, remote_block)) in enumerate(
            zip(self.ginlayers, blocks)
        ):
            assert h_local.shape[0] == local_block.number_of_src_nodes()
            assert h_remote.shape[0] == remote_block.number_of_src_nodes()

            h = layer(local_block, remote_block, h_local, h_remote)

            if self.use_batch_norm:
                h = self.batch_norms[i](h)

            h = self.activation(h)

            if i != last_index:
                # Split this layer's output rows into the source features of
                # the next layer's local and remote blocks.
                local_indices, remote_indices = transform_indices[i]
                h_local = h[local_indices]
                h_remote = h[remote_indices]

                assert h_local.shape[0] == len(local_indices)
                assert h_remote.shape[0] == len(remote_indices)
                assert h_local.shape[0] + h_remote.shape[0] == h.shape[0]

        # Readout: map the final n_hidden features to n_classes scores, so the
        # output matches the documented contract and compute_acc's argmax
        # ranges over classes.
        return self.dropout(self.linear_prediction[-1](h))

def compute_acc(pred, labels):
    """Return the fraction of rows of ``pred`` whose arg-max equals ``labels``.

    Parameters
    ----------
    pred : torch.Tensor
        Score matrix; the predicted class of row ``i`` is ``pred[i].argmax()``.
    labels : torch.Tensor
        Ground-truth class per row (cast to ``long`` before comparing).

    Returns
    -------
    torch.Tensor
        Zero-dimensional float tensor: number of correct rows divided by
        ``len(pred)``.
    """
    target = labels.long()
    hits = pred.argmax(dim=1) == target
    return hits.float().sum() / len(pred)


def evaluate(model, g, valid_sampler, device, valid_nid, args):
    """Estimate accuracy on one random mini-batch of ``valid_nid``.

    Up to ``args.valid_batch_size`` node IDs are drawn from ``valid_nid``
    without replacement. Their (local, remote) block pairs are sampled with
    ``valid_sampler``, moved to ``device``, and ``model`` runs on them in eval
    mode without gradient tracking.

    Parameters
    ----------
    model : ClientServerGIN
        Model to evaluate. Its train/eval mode is restored on return.
    g : DGLGraph
        Graph passed to ``valid_sampler.sample``.
    valid_sampler : object
        Must provide ``sample(g, seeds) -> (blocks, transform_indices)``.
    device : torch.device or str
        Device the blocks and features are moved to.
    valid_nid : torch.Tensor
        Candidate validation node IDs.
    args : object
        Must provide ``valid_batch_size``.

    Returns
    -------
    torch.Tensor
        Accuracy from :func:`compute_acc` against the ``labels`` of the last
        local block's destination nodes.
    """
    # Random permutation + prefix slice = sample without replacement.
    # Advanced indexing already returns a new tensor, so no copy is needed.
    shuffle_index = th.randperm(len(valid_nid))
    seeds = valid_nid[shuffle_index][:args.valid_batch_size]

    blocks, transform_indices = valid_sampler.sample(g, seeds)
    # Moving to the device
    blocks = [(local.to(device), remote.to(device)) for local, remote in blocks]

    # Extract features
    first_local_block, first_remote_block = blocks[0]
    local_feat_src = first_local_block.srcdata['features'].to(device)
    remote_feat_src = first_remote_block.srcdata['features'].to(device)

    # Extract labels. Labels come only from the local side, so the last
    # remote block is required to carry no edges.
    last_local_block, last_remote_block = blocks[-1]
    assert last_remote_block.number_of_edges() == 0 # The last remote block should be empty
    labels = last_local_block.dstdata['labels'].to(device)

    was_training = model.training
    model.eval()
    try:
        with th.no_grad():
            batch_pred = model(blocks, local_feat_src, remote_feat_src, transform_indices)
    finally:
        # Restore the caller's mode instead of forcing train mode, even if
        # the forward pass raises.
        model.train(was_training)
    return compute_acc(batch_pred, labels)
        

class TestCustomGINConv(unittest.TestCase):
    """Check CustomGINConv on a local/remote split against GINConv on the full graph."""

    def test_forward(self):
        # Fixed seed so features and MLP weights are reproducible between
        # runs, and a failure can be re-run and inspected.
        th.manual_seed(0)

        # Directed graph built from (src, dst) tensors, the input form
        # documented for ``dgl.graph``. Edges:
        # (0,2) (0,3) (1,2) (1,4) (2,5) (3,4) (3,6) (4,5) (4,6) (5,6)
        src = th.tensor([0, 0, 1, 1, 2, 3, 3, 4, 4, 5])
        dst = th.tensor([2, 3, 2, 4, 5, 4, 6, 5, 6, 6])
        g = dgl.graph((src, dst))

        # fanout=5 exceeds every in-degree here (max 3), so sampling keeps the
        # full neighbourhood. That is what makes the comparison with
        # full-graph GINConv exact up to float rounding.
        # Remote block: destination nodes 3, 4 and 5.
        remote_nodes = th.tensor([3, 4, 5])
        remote_frontier = dgl.sampling.sample_neighbors(g, remote_nodes, fanout=5)
        remote_block = dgl.to_block(remote_frontier, remote_nodes)

        # Local block: destination nodes 2 and 6.
        local_nodes = th.tensor([2, 6])
        local_frontier = dgl.sampling.sample_neighbors(g, local_nodes, fanout=5)
        local_block = dgl.to_block(local_frontier, local_nodes)

        # Assign node features
        in_feats = 16
        g.ndata['feat'] = th.randn(g.num_nodes(), in_feats)

        # Gather the source-node features of each block by original node ID.
        remote_feat = g.ndata['feat'][remote_block.srcdata[dgl.NID]]
        local_feat = g.ndata['feat'][local_block.srcdata[dgl.NID]]

        out_feats = 32
        activation = th.nn.functional.relu

        # MLP shared by both implementations as the apply_func
        apply_func = nn.Sequential(
            nn.Linear(in_feats, out_feats),
            nn.ReLU(),
            nn.Linear(out_feats, out_feats)
        )

        conv = CustomGINConv(apply_func, 'sum', init_eps=0.0, activation=activation)
        h = conv(local_block, remote_block, local_feat, remote_feat)

        # One row per destination node, local rows first.
        expected_shape = (local_block.num_dst_nodes() + remote_block.num_dst_nodes(), out_feats)
        self.assertEqual(h.shape, expected_shape)

        # The final ReLU must leave no negative entries.
        self.assertTrue(th.all(h >= 0))

        # Reference: the official GINConv on the whole graph, restricted to
        # the same destination nodes in the same (local, remote) order.
        official_conv = GINConv(apply_func, aggregator_type='sum', init_eps=0.0, activation=activation)
        official_h = official_conv(g, g.ndata['feat'])
        official_h_local = official_h[local_block.dstdata[dgl.NID]]
        official_h_remote = official_h[remote_block.dstdata[dgl.NID]]
        official_h_combined = th.cat([official_h_local, official_h_remote], dim=0)

        # Both paths apply the same apply_func to the same neighbourhood sums,
        # so they should agree up to floating-point summation order.
        max_diff = (h - official_h_combined).abs().max().item()
        self.assertTrue(
            th.allclose(h, official_h_combined, rtol=1e-3, atol=1e-3),
            msg="CustomGINConv deviates from GINConv (max abs diff {:.3e})".format(max_diff),
        )


if __name__ == "__main__":
    import sys

    # Run the CustomGINConv test
    print("\nRunning CustomGINConv test:")
    test_suite = unittest.TestLoader().loadTestsFromTestCase(TestCustomGINConv)
    result = unittest.TextTestRunner(verbosity=2).run(test_suite)
    # Propagate failures to the exit status so shells/CI can detect them.
    sys.exit(0 if result.wasSuccessful() else 1)